#
# Copyright (C) [2024] Xingyun Integrated Circuit, Inc.
#
# GreenCode was a private technology asset of Xingyun Integrated Circuit, Inc. (Confidential)
# Author: Shawn.Tan
# Date: 2025.10.28
#
# History: Initial Version 2025.10.28
#
#
import RandomUtils

import exception_handlers_test_utils
from base.Sequence import Sequence
from gpgpu.EnvGPGPU import EnvGPGPU
from gpgpu.GenThreadGPGPU import GenThreadGPGPU
from gpgpu.Utils import LoadGPR64


# This test verifies that exceptions are triggered as expected and that
# generation and simulation can continue after the exceptions are handled. We
# reserve a selection of GPRs to test the exception dispatcher register save
# and restore logic. Reserving the registers prevents the random instructions
# generated by this test from modifying the registers. Exception handlers may
# still modify reserved registers, but must restore them before returning.
class MainSequence(Sequence):
    """Trigger misaligned-access and ECALL exceptions and check recovery.

    Each iteration generates one instruction expected to raise a specific
    exception code, then verifies that a new exception record for that code
    exists and that the reserved GPRs still hold their original values.
    """

    # Instruction names used to provoke each misaligned-access exception code,
    # as (RV32 instruction, RV64 instruction). Code 4 is provoked with a
    # misaligned load and code 6 with a misaligned store.
    _MISALIGNED_ACCESS_INSTRS = {
        4: ("LW##CALM", "LD##CALM"),
        6: ("SW##CALM", "SD##CALM"),
    }

    # Exception code expected from ECALL, keyed by the PrivilegeLevel value
    # reported by getPEstate at the time the ECALL is generated.
    _ECALL_EXCEPT_CODES = {0: 8, 1: 9, 3: 11}

    def __init__(self, gen_thread, name=None):
        super().__init__(gen_thread, name)

        # Last observed exception-record count per exception code. The keys
        # must cover every code generate() can produce (see the two class
        # tables above); their order also fixes the order in which
        # _configureExceptionDelegation draws random bits.
        self._mExceptCounts = {4: 0, 6: 0, 8: 0, 9: 0, 11: 0}

    def generate(self, **kargs):
        """Generate 100 exception-triggering instructions and check each one.

        Exception delegation is configured and MRET is issued before any
        exceptions are generated. Twenty random GPRs are reserved so the
        random instructions cannot modify them; their values are re-checked
        after every exception to exercise the handler save/restore logic.
        """
        self._configureExceptionDelegation(self._mExceptCounts.keys())
        self._switchToRandomPrivilegeLevel()

        rv32 = self.getGlobalState("AppRegisterWidth") == 32

        orig_gpr_values = exception_handlers_test_utils.reserve_random_gprs(self, 20)
        for _ in range(100):
            # Query the privilege level every iteration: the ECALL exception
            # code depends on the level the ECALL executes at.
            priv_level = self.getPEstate("PrivilegeLevel")
            ecall_except_code = self._ECALL_EXCEPT_CODES[priv_level]
            except_code = self.choice((*self._MISALIGNED_ACCESS_INSTRS, ecall_except_code))

            if except_code == ecall_except_code:
                self.genInstruction("ECALL##CALM")
            else:
                # genVA is asked for an 8-byte-aligned address; offsetting it
                # by one byte makes both word and doubleword accesses misaligned.
                unaligned_target_addr = self.genVA(Size=16, Align=8, Type="D") + 1
                rv32_instr, rv64_instr = self._MISALIGNED_ACCESS_INSTRS[except_code]
                self.genInstruction(
                    rv32_instr if rv32 else rv64_instr,
                    {"LSTarget": unaligned_target_addr},
                )

            self._verifyExceptionCount(except_code)
            exception_handlers_test_utils.assert_gpr_values_unchanged(self, orig_gpr_values)

    def _configureExceptionDelegation(self, aExceptCodes):
        """Build a random medeleg value over the given exception codes.

        For each code in aExceptCodes a random bit (0 or 1) is placed at that
        bit position, and the resulting value is loaded into a random GPR
        other than x0 via LoadGPR64.

        NOTE(review): the CSRRW that would write this value into medeleg is
        commented out below, so as written no delegation is configured and
        the loaded GPR is never consumed. The random draws and the GPR load
        are still generated, which keeps the generated stream unchanged.
        Confirm whether the CSR write should be restored for this target.

        :param aExceptCodes: iterable of exception codes eligible for
            delegation (bit positions in the medeleg value).
        """
        # Spike appears to only allow exception codes 0, 3, 8, 12, 13 and 15
        # to be delegated. The logic below harmlessly ignores this constraint
        # for simplicity and in case Spike's implementation changes in the
        # future. There does not appear to be any provisions in the GPGPU
        # Privileged Architecture Specification that mandate such a
        # restrictive delegation scheme.
        medeleg_val = 0
        for except_code in aExceptCodes:
            medeleg_val |= RandomUtils.random32(0, 1) << except_code

        load_gpr64_seq = LoadGPR64(self.genThread)
        medeleg_val_reg_index = self.getRandomGPR(exclude="0")
        load_gpr64_seq.load(medeleg_val_reg_index, medeleg_val)

        # self.genInstruction(
        #   'CSRRW#register#CALM', {'rd': 0, 'rs1': medeleg_val_reg_index,
        #                            'csr': self.getRegisterIndex('medeleg')})

    def _switchToRandomPrivilegeLevel(self):
        """Issue an MRET to leave the current privilege level.

        This method does not pick a level itself; the destination is
        presumably whatever the generator has staged for the return
        (e.g. mstatus.MPP) — TODO confirm.
        """
        self.genInstruction("MRET##CALM")

    def _verifyExceptionCount(self, aExceptCode):
        """Check that a new exception record exists for aExceptCode.

        On success, the stored count for the code is advanced to the newly
        queried value. Otherwise self.error is reported and the stored
        count is left unchanged.

        :param aExceptCode: code of the exception that should have been
            triggered; must be a key of self._mExceptCounts.
        """
        new_except_count = self.queryExceptionRecordsCount(aExceptCode)
        if new_except_count > self._mExceptCounts[aExceptCode]:
            self._mExceptCounts[aExceptCode] = new_except_count
        else:
            self.error("An exception with code %d was not triggered" % aExceptCode)


# Module-level entry points, presumably looked up by name by the test
# framework: the sequence to run, and the thread/environment classes for it.
MainSequenceClass, GenThreadClass, EnvClass = MainSequence, GenThreadGPGPU, EnvGPGPU
